//
// Created by Xing on 2022/3/26.
//

#include <vector>
#include "caffe/layers/expand_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

    /**
     * @brief Reads the target shape from expand_param and caches it in shape_.
     *
     * A leading dimension of 0 or -1 in the configured shape is replaced by the
     * size of axis 0 of bottom[0]; every other dimension is copied as given.
     * Aborts (via CHECK) unless there is exactly one bottom blob, exactly one
     * top blob, and the configured shape has at least one dimension.
     *
     * @param bottom input blobs; only bottom[0] is used.
     * @param top    output blobs; only their count is validated here.
     */
    template<typename Dtype>
    void ExpandLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        CHECK_EQ(bottom.size(), 1);
        CHECK_EQ(top.size(), 1);

        const ExpandParameter& param = this->layer_param().expand_param();
        const BlobShape& shape = param.shape();
        // dim(0) is read unconditionally below, so an empty shape must be rejected first.
        CHECK_GT(shape.dim_size(), 0) << "expand_param.shape must specify at least one dimension";

        shape_.clear();
        if (shape.dim(0) == 0 || shape.dim(0) == -1) {
            // 0 / -1 means "take the leading dimension from the input".
            shape_.push_back(bottom[0]->shape(0));
        } else {
            shape_.push_back(static_cast<int>(shape.dim(0)));
        }
        for (int i = 1; i < shape.dim_size(); ++i) {
            shape_.push_back(static_cast<int>(shape.dim(i)));
        }
    }

    /**
     * @brief Validates bottom[0] against the cached target shape and reshapes top[0] to it.
     *
     * The bottom must have the same number of axes as shape_. Axis 0 must match
     * exactly. Every other axis must either have extent 1 in the bottom (it will
     * be repeated) or already equal the target extent. Any violation aborts via
     * CHECK.
     *
     * When the configured leading dimension is 0 or -1 (taken from the input),
     * shape_[0] is refreshed from the current bottom, so a bottom whose leading
     * dimension differs from the one seen in LayerSetUp is still accepted.
     *
     * @param bottom input blobs; only bottom[0] is inspected.
     * @param top    output blobs; top[0] is reshaped to shape_.
     */
    template<typename Dtype>
    void ExpandLayer<Dtype>::Reshape(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        const int num_axes = static_cast<int>(shape_.size());
        // Verify the rank first so that the per-axis lookups below are in range.
        CHECK_EQ(num_axes, bottom[0]->num_axes());

        // An inferred leading dimension follows the current input.
        const auto& configured = this->layer_param().expand_param().shape();
        if (configured.dim_size() > 0 && (configured.dim(0) == 0 || configured.dim(0) == -1)) {
            shape_[0] = bottom[0]->shape(0);
        }

        CHECK_EQ(shape_[0], bottom[0]->shape(0));  // the batch dimension must be equal
        for (int axis = 1; axis < num_axes; ++axis) {
            const int bottom_dim = bottom[0]->shape(axis);
            if (bottom_dim == 1) {
                continue;  // extent-1 axes are expanded to the target extent
            }
            CHECK_EQ(shape_[axis], bottom_dim);
        }
        top[0]->Reshape(shape_);
    }

    /**
     * @brief CPU forward pass: fills top[0] from bottom[0], repeating the bottom
     *        along every axis whose bottom extent is 1.
     *
     * Works for any number of axes. The copy walks the top blob linearly and
     * keeps a matching bottom offset; axes with bottom extent 1 get a stride of
     * 0, so the same bottom element is reused along them. Relies on Reshape
     * having verified that bottom[0] has as many axes as shape_ and that each
     * non-unit bottom extent equals the corresponding entry of shape_.
     *
     * @param bottom input blobs; bottom[0] is read.
     * @param top    output blobs; top[0] is overwritten completely.
     */
    template<typename Dtype>
    void ExpandLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
        Dtype *top_data = top[0]->mutable_cpu_data();
        const Dtype *bottom_data = bottom[0]->cpu_data();
        const int num_axes = static_cast<int>(shape_.size());

        // Row-major strides of the bottom blob, with 0 for axes that are repeated.
        vector<int> in_stride(num_axes, 0);
        int dense_stride = 1;
        for (int axis = num_axes - 1; axis >= 0; --axis) {
            const int in_dim = bottom[0]->shape(axis);
            in_stride[axis] = in_dim == 1 ? 0 : dense_stride;
            dense_stride *= in_dim;
        }

        const int count = top[0]->count();
        vector<int> index(num_axes, 0);  // current multi-dimensional position in the top blob
        int in_offset = 0;
        for (int out_offset = 0; out_offset < count; ++out_offset) {
            top_data[out_offset] = bottom_data[in_offset];
            // Advance the position like an odometer, innermost axis first,
            // keeping in_offset in sync with the new position.
            for (int axis = num_axes - 1; axis >= 0; --axis) {
                in_offset += in_stride[axis];
                if (++index[axis] < shape_[axis]) {
                    break;
                }
                // This axis wrapped around: rewind it and carry into the next outer axis.
                in_offset -= in_stride[axis] * shape_[axis];
                index[axis] = 0;
            }
        }
    }

    // Caffe-provided macros (defined outside this file).
    // Presumably emits the explicit template instantiations of ExpandLayer
    // for the supported Dtype values — confirm against caffe/common.hpp.
    INSTANTIATE_CLASS(ExpandLayer);
    // Presumably registers the layer with Caffe's layer factory under the
    // type name "Expand" — confirm against caffe/layer_factory.hpp.
    REGISTER_LAYER_CLASS(Expand);
}  // namespace caffe